//go:build sql_integration

package m191tom192

import (
	"context"
	"testing"

	"github.com/stackrox/rox/generated/storage"
	oldSchema "github.com/stackrox/rox/migrator/migrations/m_191_to_m_192_vulnerability_requests_searchable_scope/schema/old"
	previousStore "github.com/stackrox/rox/migrator/migrations/m_191_to_m_192_vulnerability_requests_searchable_scope/store/previous"
	updatedStore "github.com/stackrox/rox/migrator/migrations/m_191_to_m_192_vulnerability_requests_searchable_scope/store/updated"
	pghelper "github.com/stackrox/rox/migrator/migrations/postgreshelper"
	"github.com/stackrox/rox/migrator/types"
	"github.com/stackrox/rox/pkg/fixtures"
	"github.com/stackrox/rox/pkg/postgres/pgutils"
	"github.com/stackrox/rox/pkg/sac"
	"github.com/stackrox/rox/pkg/search"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"github.com/stretchr/testify/suite"
)

// migrationTestSuite holds the shared state for the m_191_to_m_192 migration tests:
// an all-access context and a test Postgres handle, both populated in SetupSuite.
type migrationTestSuite struct {
	suite.Suite

	ctx context.Context
	db  *pghelper.TestPostgres
}

func TestMigration(t *testing.T) {
	suite.Run(t, new(migrationTestSuite))
}

// SetupSuite runs once before the suite's tests. It stores an all-access context,
// obtains a test Postgres handle, and creates the vulnerability requests table from
// the OLD (pre-migration) schema so the migration has rows in the old layout to upgrade.
// NOTE(review): the second ForT argument is false; presumably this means no schema is
// pre-created, which is why the old table is created explicitly here — confirm.
// NOTE(review): no TearDownSuite is visible in this file; confirm the test database is
// cleaned up elsewhere (e.g. by pghelper itself).
func (s *migrationTestSuite) SetupSuite() {
	allAccessCtx := sac.WithAllAccess(context.Background())
	s.ctx = allAccessCtx
	s.db = pghelper.ForT(s.T(), false)

	gormDB := s.db.GetGormDB()
	pgutils.CreateTableFromModel(s.ctx, gormDB, oldSchema.CreateTableVulnerabilityRequestsStmt)
}

// TestMigration seeds the pre-migration vulnerability requests table with a mix of
// image-scoped and global-scoped requests through the previous store, runs the
// migration, and then checks through the updated store that:
//   - every seeded request is still returned by an unfiltered query, and
//   - the image scope fields (registry, remote, tag) can be searched, including
//     exact matches on an empty tag and a null-field query on the remote.
func (s *migrationTestSuite) TestMigration() {
	imageScopedReqs := []*storage.VulnerabilityRequest{
		fixtures.GetImageScopeDeferralRequest("reg-1", "remote-1", "tag-1", "cve-1"),
		fixtures.GetImageScopeDeferralRequest("reg-2", "remote-1", "tag-1", "cve-2"),
		fixtures.GetImageScopeDeferralRequest("reg-3", "remote-1", "tag-1", ""),
		fixtures.GetImageScopeDeferralRequest("reg-4", "remote-1", "", ""),
		fixtures.GetImageScopeDeferralRequest("reg-5", "", "", ""),
		fixtures.GetImageScopeDeferralRequest("reg-6", "remote-2", ".*", "cve-1"),
	}

	globalScopedReqs := []*storage.VulnerabilityRequest{
		fixtures.GetGlobalFPRequest("cve-1"),
		fixtures.GetGlobalFPRequest(""),
		fixtures.GetGlobalDeferralRequest("cve-1"),
		fixtures.GetGlobalDeferralRequest(""),
	}

	allIDs := collectIDs(imageScopedReqs...)
	allIDs = append(allIDs, collectIDs(globalScopedReqs...)...)

	oldStore := previousStore.New(s.db)
	require.NoError(s.T(), oldStore.UpsertMany(s.ctx, imageScopedReqs))
	require.NoError(s.T(), oldStore.UpsertMany(s.ctx, globalScopedReqs))

	dbs := &types.Databases{
		GormDB:     s.db.GetGormDB(),
		PostgresDB: s.db.DB,
		DBCtx:      s.ctx,
	}
	require.NoError(s.T(), migration.Run(dbs))

	newStore := updatedStore.New(s.db)

	// The migration must not lose anything: an unfiltered query returns every seeded request.
	all, err := newStore.GetByQuery(s.ctx, search.EmptyQuery())
	require.NoError(s.T(), err)
	assert.ElementsMatch(s.T(), allIDs, collectIDs(all...))

	// assertMatches runs the query built by qb against the migrated store and checks that
	// it returns exactly the expected requests (order-insensitive). A query error aborts
	// the test right away instead of also reporting a misleading mismatch against a nil
	// result.
	assertMatches := func(qb *search.QueryBuilder, expected ...*storage.VulnerabilityRequest) {
		found, qErr := newStore.GetByQuery(s.ctx, qb.ProtoQuery())
		require.NoError(s.T(), qErr)
		assert.ElementsMatch(s.T(), collectIDs(expected...), collectIDs(found...))
	}

	assertMatches(
		search.NewQueryBuilder().AddExactMatches(search.ImageRegistryScope, "reg-1"),
		imageScopedReqs[0],
	)
	assertMatches(
		search.NewQueryBuilder().AddExactMatches(search.ImageRemoteScope, "remote-1"),
		imageScopedReqs[0], imageScopedReqs[1], imageScopedReqs[2], imageScopedReqs[3],
	)
	assertMatches(
		search.NewQueryBuilder().AddExactMatches(search.ImageTagScope, "tag-1"),
		imageScopedReqs[0], imageScopedReqs[1], imageScopedReqs[2],
	)
	// The fixtures created with an empty tag (reg-4, reg-5) are the expected matches here.
	assertMatches(
		search.NewQueryBuilder().AddExactMatches(search.ImageTagScope, ""),
		imageScopedReqs[3], imageScopedReqs[4],
	)
	// Only the global-scoped requests are expected to have a null remote scope field.
	assertMatches(
		search.NewQueryBuilder().AddNullField(search.ImageRemoteScope),
		globalScopedReqs...,
	)
}

// collectIDs returns the GetId() value of each request, in argument order.
// With no arguments it returns a nil slice.
func collectIDs(reqs ...*storage.VulnerabilityRequest) []string {
	var out []string
	for i := range reqs {
		out = append(out, reqs[i].GetId())
	}
	return out
}
